{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "984ff14f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1d830d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -r tb_logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "764519b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torchvision import transforms as T\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "import torchmetrics\n",
    "from torchmetrics import Metric\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "487d1aaf",
   "metadata": {},
   "source": [
    "# PyTorch Lightning Model Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8401bbf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LightningModel(pl.LightningModule):\n",
    "    def __init__(self, input_size, num_classes):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_size, 128)\n",
    "        self.fc2 = nn.Linear(128, 64)\n",
    "        self.fc3 = nn.Linear(64, num_classes)\n",
    "        self.accuracy = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n",
    "        self.f1score = torchmetrics.F1Score(task=\"multiclass\", num_classes=num_classes)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = nn.functional.relu(self.fc1(x))\n",
    "        x = nn.functional.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "    \n",
    "    def _forward_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        x = x.view(x.size(0), -1)\n",
    "        scores = self.forward(x)\n",
    "        loss = nn.functional.cross_entropy(scores, y)\n",
    "        return loss, scores, y\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        loss, scores, y = self._forward_step(batch, batch_idx)\n",
    "        accuracy = self.accuracy(scores, y)\n",
    "        f1score = self.f1score(scores, y)\n",
    "        self.log_dict({\"train_loss\": loss, \"train_accuracy\": accuracy,\n",
    "                       \"train_f1_score\": f1score}, on_epoch=True, prog_bar=True)\n",
    "        if batch_idx % 10 == 0:\n",
    "            images, _ = batch\n",
    "            indices = torch.randperm(len(images))[:8]\n",
    "            images = images[indices]\n",
    "            grid = torchvision.utils.make_grid(images.view(-1, 1, 28, 28))\n",
    "            self.logger.experiment.add_image(f\"MNIST Input Image - {batch_idx}\", grid, 0)\n",
    "        \n",
    "        return loss\n",
    "    \n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        loss, scores, y = self._forward_step(batch, batch_idx)\n",
    "        self.log(f\"Validation Loss {batch_idx}\", loss)\n",
    "        return loss\n",
    "    \n",
    "    def test_step(self, batch, batch_idx):\n",
    "        loss, scores, y = self._forward_step(batch, batch_idx)\n",
    "        return loss\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        return optim.Adam(self.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c545f38",
   "metadata": {},
   "source": [
    "### Loading Traning and Validation Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "2779e5ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "my_transforms = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])\n",
    "train_data = torchvision.datasets.MNIST('./data/', train=True, transform=my_transforms, download=True)\n",
    "train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=64, num_workers=1)\n",
    "val_data = torchvision.datasets.MNIST('./data/', train=False, transform=my_transforms, download=True)\n",
    "val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=64, num_workers=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "163fbd01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 784]) torch.Size([64])\n",
      "torch.Size([64, 784]) torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "for x, y in train_data_loader:\n",
    "    print(x.shape, y.shape)\n",
    "    break\n",
    "for x, y in val_data_loader:\n",
    "    print(x.shape, y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28ad658a",
   "metadata": {},
   "source": [
    "### Initializing Lightning Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "d8c545f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "tb_logger = TensorBoardLogger(\"tb_logs\", name='mnist_model_v0')\n",
    "trainer = pl.Trainer(max_epochs=8, logger=tb_logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "6cd6ea4b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Missing logger folder: tb_logs/mnist_model_v0\n",
      "\n",
      "  | Name     | Type               | Params\n",
      "------------------------------------------------\n",
      "0 | fc1      | Linear             | 100 K \n",
      "1 | fc2      | Linear             | 8.3 K \n",
      "2 | fc3      | Linear             | 650   \n",
      "3 | accuracy | MulticlassAccuracy | 0     \n",
      "4 | f1score  | MulticlassF1Score  | 0     \n",
      "------------------------------------------------\n",
      "109 K     Trainable params\n",
      "0         Non-trainable params\n",
      "109 K     Total params\n",
      "0.438     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |                                                                                 | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "913f198488e942d781d176a32f2f9566",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |                                                                                        | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=8` reached.\n"
     ]
    }
   ],
   "source": [
    "model = LightningModel(784, 10)\n",
    "trainer.fit(model, train_data_loader, val_data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5591fb61",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3b43b8cd",
   "metadata": {},
   "source": [
    "### Initializing Lightning Trainer with GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0042e9a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7f9810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LightningModel(784, 10)\n",
    "trainer.fit(model, train_dataloader=train_data_loader, val_dataloader=val_data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7186e2c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6a82a1c6",
   "metadata": {},
   "source": [
    "### Training with Lightning Data Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b18a2c1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataModule(pl.LightningDataModule):\n",
    "  def __init__(self, data_dir, batch_size, num_workers):\n",
    "    super().__init__()\n",
    "    self.data_dir = data_dir\n",
    "    self.batch_size = batch_size\n",
    "    self.num_workers = num_workers\n",
    "    self.my_transforms = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])\n",
    "\n",
    "  def prepare_data(self):\n",
    "    torchvision.datasets.MNIST(self.data_dir, download=True, train=True)\n",
    "    torchvision.datasets.MNIST(self.data_dir, download=True, train=False)\n",
    "\n",
    "  def setup(self, stage):\n",
    "    entire_dataset = torchvision.datasets.MNIST(self.data_dir, train=True, download=False,\n",
    "                                                transform=self.my_transforms)\n",
    "    self.train_ds, self.test_ds = torch.utils.data.random_split(entire_dataset, [50_000, 10_000])\n",
    "\n",
    "    self.val_ds = torchvision.datasets.MNIST(self.data_dir, train=False, download=False,\n",
    "                                             transform=self.my_transforms)\n",
    "\n",
    "  def train_dataloader(self):\n",
    "    return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size,\n",
    "                                       num_workers=self.num_workers)\n",
    "    \n",
    "  def test_dataloader(self):\n",
    "    return torch.utils.data.DataLoader(self.test_ds, batch_size=self.batch_size,\n",
    "                                       num_workers=self.num_workers)\n",
    "  \n",
    "  def val_dataloader(self):\n",
    "    return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size,\n",
    "                                       num_workers=self.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "38e9e8dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "dm = DataModule('./', 64, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "66c51d46",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "tb_logger = TensorBoardLogger(\"tb_logs\", name='mnist_dm_model_v0')\n",
    "trainer = pl.Trainer(max_epochs=8, logger=tb_logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8ee3cff7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  | Name     | Type               | Params\n",
      "------------------------------------------------\n",
      "0 | fc1      | Linear             | 100 K \n",
      "1 | fc2      | Linear             | 8.3 K \n",
      "2 | fc3      | Linear             | 650   \n",
      "3 | accuracy | MulticlassAccuracy | 0     \n",
      "4 | f1score  | MulticlassF1Score  | 0     \n",
      "------------------------------------------------\n",
      "109 K     Trainable params\n",
      "0         Non-trainable params\n",
      "109 K     Total params\n",
      "0.438     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |                                                                                 | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/atifadib/opt/anaconda3/envs/torch_env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01706210a3154f3387f8b44c737c618c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |                                                                                        | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                      | 0/? [00:0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=8` reached.\n"
     ]
    }
   ],
   "source": [
    "model = LightningModel(784, 10)\n",
    "trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e2e7746",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "torch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
